{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 5.1 Basic Principles of Decision Tree Models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 5.1.3 Code Implementation of Decision Tree Models"
   ]
  },
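  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Decision tree models can be used for classification (predicting a categorical variable) as well as for regression (predicting a continuous variable). The corresponding models are the classification decision tree (DecisionTreeClassifier) and the regression decision tree (DecisionTreeRegressor)."
   ]
  },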
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**1. Classification decision tree model (DecisionTreeClassifier)**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0]\n"
     ]
    }
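   ],
   "source": [
    "from sklearn.tree import DecisionTreeClassifier\n",
    "\n",
    "# Five training samples with two features each, and their class labels\n",
    "X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]\n",
    "y = [1, 0, 0, 1, 1]\n",
    "\n",
    "# Fix random_state so the fitted tree is the same on every run\n",
    "model = DecisionTreeClassifier(random_state=0)\n",
    "model.fit(X, y)\n",
    "\n",
    "# Predict the class of a single new sample\n",
    "print(model.predict([[5, 5]]))"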
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To predict several samples at once, pass them together in a single list:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0 0 1]\n"
     ]
    }
   ],
   "source": [
    "print(model.predict([[5, 5], [7, 7], [9, 9]]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Supplementary note: visualizing the decision tree (for interested readers)**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The decision tree can be visualized with the graphviz plugin. For a tutorial on using graphviz, see this document: https://shimo.im/docs/Dcgw8H6WxgWrc8hq/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: graphviz in d:\\programfiles\\anaconda3\\lib\\site-packages (0.20.1)\n"
     ]
    }
   ],
   "source": [
    "!pip install graphviz "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 10.0.1 (20240210.2158)\n",
       " -->\n",
       "<!-- Title: Tree Pages: 1 -->\n",
       "<svg width=\"275pt\" height=\"335pt\"\n",
       " viewBox=\"0.00 0.00 275.25 335.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 331)\">\n",
       "<title>Tree</title>\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-331 271.25,-331 271.25,4 -4,4\"/>\n",
       "<!-- 0 -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>0</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"210.25,-327 114,-327 114,-236.5 210.25,-236.5 210.25,-327\"/>\n",
       "<text text-anchor=\"middle\" x=\"162.12\" y=\"-309.7\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">x[1] &lt;= 7.0</text>\n",
       "<text text-anchor=\"middle\" x=\"162.12\" y=\"-293.2\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">gini = 0.48</text>\n",
       "<text text-anchor=\"middle\" x=\"162.12\" y=\"-276.7\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 5</text>\n",
       "<text text-anchor=\"middle\" x=\"162.12\" y=\"-260.2\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [2, 3]</text>\n",
       "<text text-anchor=\"middle\" x=\"162.12\" y=\"-243.7\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = 1</text>\n",
       "</g>\n",
       "<!-- 1 -->\n",
       "<g id=\"node2\" class=\"node\">\n",
       "<title>1</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"153.25,-200.5 57,-200.5 57,-110 153.25,-110 153.25,-200.5\"/>\n",
       "<text text-anchor=\"middle\" x=\"105.12\" y=\"-183.2\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">x[0] &lt;= 2.0</text>\n",
       "<text text-anchor=\"middle\" x=\"105.12\" y=\"-166.7\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">gini = 0.444</text>\n",
       "<text text-anchor=\"middle\" x=\"105.12\" y=\"-150.2\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 3</text>\n",
       "<text text-anchor=\"middle\" x=\"105.12\" y=\"-133.7\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [2, 1]</text>\n",
       "<text text-anchor=\"middle\" x=\"105.12\" y=\"-117.2\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = 0</text>\n",
       "</g>\n",
       "<!-- 0&#45;&gt;1 -->\n",
       "<g id=\"edge1\" class=\"edge\">\n",
       "<title>0&#45;&gt;1</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M141.71,-236.15C138.02,-228.11 134.14,-219.63 130.34,-211.33\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"133.57,-209.96 126.22,-202.33 127.2,-212.88 133.57,-209.96\"/>\n",
       "<text text-anchor=\"middle\" x=\"116.88\" y=\"-219.33\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">True</text>\n",
       "</g>\n",
       "<!-- 4 -->\n",
       "<g id=\"node5\" class=\"node\">\n",
       "<title>4</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"267.25,-192.25 171,-192.25 171,-118.25 267.25,-118.25 267.25,-192.25\"/>\n",
       "<text text-anchor=\"middle\" x=\"219.12\" y=\"-174.95\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">gini = 0.0</text>\n",
       "<text text-anchor=\"middle\" x=\"219.12\" y=\"-158.45\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 2</text>\n",
       "<text text-anchor=\"middle\" x=\"219.12\" y=\"-141.95\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 2]</text>\n",
       "<text text-anchor=\"middle\" x=\"219.12\" y=\"-125.45\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = 1</text>\n",
       "</g>\n",
       "<!-- 0&#45;&gt;4 -->\n",
       "<g id=\"edge4\" class=\"edge\">\n",
       "<title>0&#45;&gt;4</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M182.54,-236.15C187.53,-225.26 192.88,-213.58 197.9,-202.61\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"200.98,-204.28 201.96,-193.73 194.62,-201.37 200.98,-204.28\"/>\n",
       "<text text-anchor=\"middle\" x=\"211.31\" y=\"-210.74\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">False</text>\n",
       "</g>\n",
       "<!-- 2 -->\n",
       "<g id=\"node3\" class=\"node\">\n",
       "<title>2</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"96.25,-74 0,-74 0,0 96.25,0 96.25,-74\"/>\n",
       "<text text-anchor=\"middle\" x=\"48.12\" y=\"-56.7\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">gini = 0.0</text>\n",
       "<text text-anchor=\"middle\" x=\"48.12\" y=\"-40.2\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 1</text>\n",
       "<text text-anchor=\"middle\" x=\"48.12\" y=\"-23.7\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 1]</text>\n",
       "<text text-anchor=\"middle\" x=\"48.12\" y=\"-7.2\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = 1</text>\n",
       "</g>\n",
       "<!-- 1&#45;&gt;2 -->\n",
       "<g id=\"edge2\" class=\"edge\">\n",
       "<title>1&#45;&gt;2</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M83.25,-109.64C79.23,-101.44 75.03,-92.87 70.98,-84.62\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"74.18,-83.19 66.64,-75.75 67.9,-86.27 74.18,-83.19\"/>\n",
       "</g>\n",
       "<!-- 3 -->\n",
       "<g id=\"node4\" class=\"node\">\n",
       "<title>3</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"210.25,-74 114,-74 114,0 210.25,0 210.25,-74\"/>\n",
       "<text text-anchor=\"middle\" x=\"162.12\" y=\"-56.7\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">gini = 0.0</text>\n",
       "<text text-anchor=\"middle\" x=\"162.12\" y=\"-40.2\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 2</text>\n",
       "<text text-anchor=\"middle\" x=\"162.12\" y=\"-23.7\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [2, 0]</text>\n",
       "<text text-anchor=\"middle\" x=\"162.12\" y=\"-7.2\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = 0</text>\n",
       "</g>\n",
       "<!-- 1&#45;&gt;3 -->\n",
       "<g id=\"edge3\" class=\"edge\">\n",
       "<title>1&#45;&gt;3</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M127,-109.64C131.02,-101.44 135.22,-92.87 139.27,-84.62\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"142.35,-86.27 143.61,-75.75 136.07,-83.19 142.35,-86.27\"/>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.sources.Source at 0x280bf9bd6d0>"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
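   ],
   "source": [
    "# If you do not need to display Chinese characters, the code below is sufficient.\n",
    "# Installing Graphviz takes some effort; see the tutorial linked above.\n",
    "from sklearn.tree import export_graphviz\n",
    "import graphviz\n",
    "\n",
    "import os  # This line and the next configure the PATH environment variable; run them once. See Section 5.2.3\n",
    "os.environ['PATH'] += os.pathsep + r'D:\\ProgramFiles\\Graphviz\\bin'  # Append (not overwrite); change this path to your Graphviz bin directory\n",
    "\n",
    "dot_data = export_graphviz(model, out_file=None, class_names=['0', '1'])\n",
    "graph = graphviz.Source(dot_data)\n",
    "graph  # graph.render('decision_tree_visualization') would write a PDF of the tree to the folder containing this code"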
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Supplementary note: what the random_state parameter does**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 10.0.1 (20240210.2158)\n",
       " -->\n",
       "<!-- Title: Tree Pages: 1 -->\n",
       "<svg width=\"275pt\" height=\"335pt\"\n",
       " viewBox=\"0.00 0.00 275.25 335.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 331)\">\n",
       "<title>Tree</title>\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-331 271.25,-331 271.25,4 -4,4\"/>\n",
       "<!-- 0 -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>0</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"210.25,-327 114,-327 114,-236.5 210.25,-236.5 210.25,-327\"/>\n",
       "<text text-anchor=\"middle\" x=\"162.12\" y=\"-309.7\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">x[0] &lt;= 6.0</text>\n",
       "<text text-anchor=\"middle\" x=\"162.12\" y=\"-293.2\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">gini = 0.48</text>\n",
       "<text text-anchor=\"middle\" x=\"162.12\" y=\"-276.7\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 5</text>\n",
       "<text text-anchor=\"middle\" x=\"162.12\" y=\"-260.2\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [2, 3]</text>\n",
       "<text text-anchor=\"middle\" x=\"162.12\" y=\"-243.7\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = 1</text>\n",
       "</g>\n",
       "<!-- 1 -->\n",
       "<g id=\"node2\" class=\"node\">\n",
       "<title>1</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"153.25,-200.5 57,-200.5 57,-110 153.25,-110 153.25,-200.5\"/>\n",
       "<text text-anchor=\"middle\" x=\"105.12\" y=\"-183.2\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">x[1] &lt;= 3.0</text>\n",
       "<text text-anchor=\"middle\" x=\"105.12\" y=\"-166.7\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">gini = 0.444</text>\n",
       "<text text-anchor=\"middle\" x=\"105.12\" y=\"-150.2\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 3</text>\n",
       "<text text-anchor=\"middle\" x=\"105.12\" y=\"-133.7\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [2, 1]</text>\n",
       "<text text-anchor=\"middle\" x=\"105.12\" y=\"-117.2\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = 0</text>\n",
       "</g>\n",
       "<!-- 0&#45;&gt;1 -->\n",
       "<g id=\"edge1\" class=\"edge\">\n",
       "<title>0&#45;&gt;1</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M141.71,-236.15C138.02,-228.11 134.14,-219.63 130.34,-211.33\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"133.57,-209.96 126.22,-202.33 127.2,-212.88 133.57,-209.96\"/>\n",
       "<text text-anchor=\"middle\" x=\"116.88\" y=\"-219.33\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">True</text>\n",
       "</g>\n",
       "<!-- 4 -->\n",
       "<g id=\"node5\" class=\"node\">\n",
       "<title>4</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"267.25,-192.25 171,-192.25 171,-118.25 267.25,-118.25 267.25,-192.25\"/>\n",
       "<text text-anchor=\"middle\" x=\"219.12\" y=\"-174.95\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">gini = 0.0</text>\n",
       "<text text-anchor=\"middle\" x=\"219.12\" y=\"-158.45\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 2</text>\n",
       "<text text-anchor=\"middle\" x=\"219.12\" y=\"-141.95\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 2]</text>\n",
       "<text text-anchor=\"middle\" x=\"219.12\" y=\"-125.45\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = 1</text>\n",
       "</g>\n",
       "<!-- 0&#45;&gt;4 -->\n",
       "<g id=\"edge4\" class=\"edge\">\n",
       "<title>0&#45;&gt;4</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M182.54,-236.15C187.53,-225.26 192.88,-213.58 197.9,-202.61\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"200.98,-204.28 201.96,-193.73 194.62,-201.37 200.98,-204.28\"/>\n",
       "<text text-anchor=\"middle\" x=\"211.31\" y=\"-210.74\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">False</text>\n",
       "</g>\n",
       "<!-- 2 -->\n",
       "<g id=\"node3\" class=\"node\">\n",
       "<title>2</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"96.25,-74 0,-74 0,0 96.25,0 96.25,-74\"/>\n",
       "<text text-anchor=\"middle\" x=\"48.12\" y=\"-56.7\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">gini = 0.0</text>\n",
       "<text text-anchor=\"middle\" x=\"48.12\" y=\"-40.2\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 1</text>\n",
       "<text text-anchor=\"middle\" x=\"48.12\" y=\"-23.7\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 1]</text>\n",
       "<text text-anchor=\"middle\" x=\"48.12\" y=\"-7.2\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = 1</text>\n",
       "</g>\n",
       "<!-- 1&#45;&gt;2 -->\n",
       "<g id=\"edge2\" class=\"edge\">\n",
       "<title>1&#45;&gt;2</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M83.25,-109.64C79.23,-101.44 75.03,-92.87 70.98,-84.62\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"74.18,-83.19 66.64,-75.75 67.9,-86.27 74.18,-83.19\"/>\n",
       "</g>\n",
       "<!-- 3 -->\n",
       "<g id=\"node4\" class=\"node\">\n",
       "<title>3</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"210.25,-74 114,-74 114,0 210.25,0 210.25,-74\"/>\n",
       "<text text-anchor=\"middle\" x=\"162.12\" y=\"-56.7\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">gini = 0.0</text>\n",
       "<text text-anchor=\"middle\" x=\"162.12\" y=\"-40.2\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 2</text>\n",
       "<text text-anchor=\"middle\" x=\"162.12\" y=\"-23.7\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [2, 0]</text>\n",
       "<text text-anchor=\"middle\" x=\"162.12\" y=\"-7.2\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = 0</text>\n",
       "</g>\n",
       "<!-- 1&#45;&gt;3 -->\n",
       "<g id=\"edge3\" class=\"edge\">\n",
       "<title>1&#45;&gt;3</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M127,-109.64C131.02,-101.44 135.22,-92.87 139.27,-84.62\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"142.35,-86.27 143.61,-75.75 136.07,-83.19 142.35,-86.27\"/>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.sources.Source at 0x280bee79090>"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.tree import DecisionTreeClassifier\n",
    "X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]\n",
    "y = [1, 0, 0, 1, 1]\n",
    "\n",
    "model = DecisionTreeClassifier()  # random_state is not set\n",
    "model.fit(X, y)\n",
    "\n",
    "# Render the fitted tree\n",
    "dot_data = export_graphviz(model, out_file=None, class_names=['0', '1'])\n",
    "graph = graphviz.Source(dot_data)\n",
    "graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 10.0.1 (20240210.2158)\n",
       " -->\n",
       "<!-- Title: Tree Pages: 1 -->\n",
       "<svg width=\"275pt\" height=\"335pt\"\n",
       " viewBox=\"0.00 0.00 275.25 335.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 331)\">\n",
       "<title>Tree</title>\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-331 271.25,-331 271.25,4 -4,4\"/>\n",
       "<!-- 0 -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>0</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"210.25,-327 114,-327 114,-236.5 210.25,-236.5 210.25,-327\"/>\n",
       "<text text-anchor=\"middle\" x=\"162.12\" y=\"-309.7\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">x[1] &lt;= 7.0</text>\n",
       "<text text-anchor=\"middle\" x=\"162.12\" y=\"-293.2\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">gini = 0.48</text>\n",
       "<text text-anchor=\"middle\" x=\"162.12\" y=\"-276.7\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 5</text>\n",
       "<text text-anchor=\"middle\" x=\"162.12\" y=\"-260.2\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [2, 3]</text>\n",
       "<text text-anchor=\"middle\" x=\"162.12\" y=\"-243.7\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = 1</text>\n",
       "</g>\n",
       "<!-- 1 -->\n",
       "<g id=\"node2\" class=\"node\">\n",
       "<title>1</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"153.25,-200.5 57,-200.5 57,-110 153.25,-110 153.25,-200.5\"/>\n",
       "<text text-anchor=\"middle\" x=\"105.12\" y=\"-183.2\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">x[1] &lt;= 3.0</text>\n",
       "<text text-anchor=\"middle\" x=\"105.12\" y=\"-166.7\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">gini = 0.444</text>\n",
       "<text text-anchor=\"middle\" x=\"105.12\" y=\"-150.2\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 3</text>\n",
       "<text text-anchor=\"middle\" x=\"105.12\" y=\"-133.7\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [2, 1]</text>\n",
       "<text text-anchor=\"middle\" x=\"105.12\" y=\"-117.2\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = 0</text>\n",
       "</g>\n",
       "<!-- 0&#45;&gt;1 -->\n",
       "<g id=\"edge1\" class=\"edge\">\n",
       "<title>0&#45;&gt;1</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M141.71,-236.15C138.02,-228.11 134.14,-219.63 130.34,-211.33\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"133.57,-209.96 126.22,-202.33 127.2,-212.88 133.57,-209.96\"/>\n",
       "<text text-anchor=\"middle\" x=\"116.88\" y=\"-219.33\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">True</text>\n",
       "</g>\n",
       "<!-- 4 -->\n",
       "<g id=\"node5\" class=\"node\">\n",
       "<title>4</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"267.25,-192.25 171,-192.25 171,-118.25 267.25,-118.25 267.25,-192.25\"/>\n",
       "<text text-anchor=\"middle\" x=\"219.12\" y=\"-174.95\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">gini = 0.0</text>\n",
       "<text text-anchor=\"middle\" x=\"219.12\" y=\"-158.45\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 2</text>\n",
       "<text text-anchor=\"middle\" x=\"219.12\" y=\"-141.95\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 2]</text>\n",
       "<text text-anchor=\"middle\" x=\"219.12\" y=\"-125.45\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = 1</text>\n",
       "</g>\n",
       "<!-- 0&#45;&gt;4 -->\n",
       "<g id=\"edge4\" class=\"edge\">\n",
       "<title>0&#45;&gt;4</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M182.54,-236.15C187.53,-225.26 192.88,-213.58 197.9,-202.61\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"200.98,-204.28 201.96,-193.73 194.62,-201.37 200.98,-204.28\"/>\n",
       "<text text-anchor=\"middle\" x=\"211.31\" y=\"-210.74\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">False</text>\n",
       "</g>\n",
       "<!-- 2 -->\n",
       "<g id=\"node3\" class=\"node\">\n",
       "<title>2</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"96.25,-74 0,-74 0,0 96.25,0 96.25,-74\"/>\n",
       "<text text-anchor=\"middle\" x=\"48.12\" y=\"-56.7\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">gini = 0.0</text>\n",
       "<text text-anchor=\"middle\" x=\"48.12\" y=\"-40.2\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 1</text>\n",
       "<text text-anchor=\"middle\" x=\"48.12\" y=\"-23.7\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 1]</text>\n",
       "<text text-anchor=\"middle\" x=\"48.12\" y=\"-7.2\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = 1</text>\n",
       "</g>\n",
       "<!-- 1&#45;&gt;2 -->\n",
       "<g id=\"edge2\" class=\"edge\">\n",
       "<title>1&#45;&gt;2</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M83.25,-109.64C79.23,-101.44 75.03,-92.87 70.98,-84.62\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"74.18,-83.19 66.64,-75.75 67.9,-86.27 74.18,-83.19\"/>\n",
       "</g>\n",
       "<!-- 3 -->\n",
       "<g id=\"node4\" class=\"node\">\n",
       "<title>3</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"210.25,-74 114,-74 114,0 210.25,0 210.25,-74\"/>\n",
       "<text text-anchor=\"middle\" x=\"162.12\" y=\"-56.7\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">gini = 0.0</text>\n",
       "<text text-anchor=\"middle\" x=\"162.12\" y=\"-40.2\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 2</text>\n",
       "<text text-anchor=\"middle\" x=\"162.12\" y=\"-23.7\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [2, 0]</text>\n",
       "<text text-anchor=\"middle\" x=\"162.12\" y=\"-7.2\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = 0</text>\n",
       "</g>\n",
       "<!-- 1&#45;&gt;3 -->\n",
       "<g id=\"edge3\" class=\"edge\">\n",
       "<title>1&#45;&gt;3</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M127,-109.64C131.02,-101.44 135.22,-92.87 139.27,-84.62\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"142.35,-86.27 143.61,-75.75 136.07,-83.19 142.35,-86.27\"/>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.sources.Source at 0x280beb78390>"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.tree import DecisionTreeClassifier\n",
    "X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]\n",
    "y = [1, 0, 0, 1, 1]\n",
    "\n",
    "model = DecisionTreeClassifier()  # random_state is not set (same code, run a second time)\n",
    "model.fit(X, y)\n",
    "\n",
    "# Render the fitted tree\n",
    "dot_data = export_graphviz(model, out_file=None, class_names=['0', '1'])\n",
    "graph = graphviz.Source(dot_data)\n",
    "graph"
   ]
  },
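  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two trees above split their nodes differently, so the same input can receive different predictions. For example, the sample [7, 7] is predicted as class 1 by the tree whose root splits on “X[0]<=6”, and as class 0 by the tree whose root splits on “X[1]<=7”.\n",
    "\n",
    "Some readers may wonder why training produces two different trees, and which of them is correct. Both are correct. Splitting the root on “X[1]<=7” or on “X[0]<=6” yields the same decrease in Gini impurity (0.48 - (0.6 × 0.444 + 0.4 × 0) ≈ 0.2136 in both cases), so either split is equally reasonable. This kind of tie is largely a consequence of the small data set: with so few samples, different splits can easily produce identical Gini decreases, whereas with more data such ties are much less likely.\n",
    "\n",
    "In short, for the same model, different splits can lead to different predictions. Setting the random_state parameter (to 0, 1, 123, or any other fixed integer) ensures that the same splits are chosen every time, so every run gives the same result. This matters for beginners, who are often puzzled when the same model produces different results from one run to the next; if that happens, setting random_state resolves it."
   ]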
  },
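  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough check of the decrease quoted above, the arithmetic can be written out in full. The notation is an assumption introduced here for illustration: $G$ denotes the Gini impurity of a node and $p_k$ the fraction of its samples in class $k$. Both candidate root splits send the training samples labelled [1, 0, 0] to one child and [1, 1] to the other, so the numbers are the same for either split.\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "G &= 1 - \\sum_k p_k^2 \\\\\n",
    "G_{\\text{root}} &= 1 - (0.4^2 + 0.6^2) = 0.48 \\\\\n",
    "G_{\\text{left}} &= 1 - \\left(\\left(\\tfrac{2}{3}\\right)^2 + \\left(\\tfrac{1}{3}\\right)^2\\right) = \\tfrac{4}{9} \\approx 0.444 \\\\\n",
    "G_{\\text{right}} &= 1 - (0^2 + 1^2) = 0 \\\\\n",
    "\\Delta G &= 0.48 - \\left(0.6 \\times \\tfrac{4}{9} + 0.4 \\times 0\\right) = \\tfrac{16}{75} \\approx 0.2133\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "The value 0.2136 in the text comes from rounding $G_{\\text{left}}$ to 0.444 before multiplying; the exact decrease is $16/75$."
   ]
  },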
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**2. Regression decision tree model (DecisionTreeRegressor)**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[4.5]\n"
     ]
    }
   ],
   "source": [
    "from sklearn.tree import DecisionTreeRegressor\n",
    "X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]\n",
    "y = [1, 2, 3, 4, 5]\n",
    "\n",
    "model = DecisionTreeRegressor(max_depth=2, random_state=0)\n",
    "model.fit(X, y)\n",
    "\n",
    "print(model.predict([[9, 9]]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 10.0.1 (20240210.2158)\n",
       " -->\n",
       "<!-- Title: Tree Pages: 1 -->\n",
       "<svg width=\"644pt\" height=\"286pt\"\n",
       " viewBox=\"0.00 0.00 644.12 285.50\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 281.5)\">\n",
       "<title>Tree</title>\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-281.5 640.12,-281.5 640.12,4 -4,4\"/>\n",
       "<!-- 0 -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>0</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"385,-277.5 243,-277.5 243,-203.5 385,-203.5 385,-277.5\"/>\n",
       "<text text-anchor=\"middle\" x=\"314\" y=\"-260.2\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">x[1] &lt;= 5.0</text>\n",
       "<text text-anchor=\"middle\" x=\"314\" y=\"-243.7\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">squared_error = 2.0</text>\n",
       "<text text-anchor=\"middle\" x=\"314\" y=\"-227.2\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 5</text>\n",
       "<text text-anchor=\"middle\" x=\"314\" y=\"-210.7\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = 3.0</text>\n",
       "</g>\n",
       "<!-- 1 -->\n",
       "<g id=\"node2\" class=\"node\">\n",
       "<title>1</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"303.12,-167.5 152.88,-167.5 152.88,-93.5 303.12,-93.5 303.12,-167.5\"/>\n",
       "<text text-anchor=\"middle\" x=\"228\" y=\"-150.2\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">x[0] &lt;= 2.0</text>\n",
       "<text text-anchor=\"middle\" x=\"228\" y=\"-133.7\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">squared_error = 0.25</text>\n",
       "<text text-anchor=\"middle\" x=\"228\" y=\"-117.2\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 2</text>\n",
       "<text text-anchor=\"middle\" x=\"228\" y=\"-100.7\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = 1.5</text>\n",
       "</g>\n",
       "<!-- 0&#45;&gt;1 -->\n",
       "<g id=\"edge1\" class=\"edge\">\n",
       "<title>0&#45;&gt;1</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M285.12,-203.24C278.32,-194.69 270.98,-185.47 263.9,-176.59\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"266.86,-174.68 257.89,-169.04 261.38,-179.04 266.86,-174.68\"/>\n",
       "<text text-anchor=\"middle\" x=\"254.13\" y=\"-187.64\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">True</text>\n",
       "</g>\n",
       "<!-- 4 -->\n",
       "<g id=\"node5\" class=\"node\">\n",
       "<title>4</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"479.25,-167.5 320.75,-167.5 320.75,-93.5 479.25,-93.5 479.25,-167.5\"/>\n",
       "<text text-anchor=\"middle\" x=\"400\" y=\"-150.2\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">x[0] &lt;= 6.0</text>\n",
       "<text text-anchor=\"middle\" x=\"400\" y=\"-133.7\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">squared_error = 0.667</text>\n",
       "<text text-anchor=\"middle\" x=\"400\" y=\"-117.2\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 3</text>\n",
       "<text text-anchor=\"middle\" x=\"400\" y=\"-100.7\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = 4.0</text>\n",
       "</g>\n",
       "<!-- 0&#45;&gt;4 -->\n",
       "<g id=\"edge4\" class=\"edge\">\n",
       "<title>0&#45;&gt;4</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M342.88,-203.24C349.68,-194.69 357.02,-185.47 364.1,-176.59\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"366.62,-179.04 370.11,-169.04 361.14,-174.68 366.62,-179.04\"/>\n",
       "<text text-anchor=\"middle\" x=\"373.87\" y=\"-187.64\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">False</text>\n",
       "</g>\n",
       "<!-- 2 -->\n",
       "<g id=\"node3\" class=\"node\">\n",
       "<title>2</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"142,-57.5 0,-57.5 0,0 142,0 142,-57.5\"/>\n",
       "<text text-anchor=\"middle\" x=\"71\" y=\"-40.2\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">squared_error = 0.0</text>\n",
       "<text text-anchor=\"middle\" x=\"71\" y=\"-23.7\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 1</text>\n",
       "<text text-anchor=\"middle\" x=\"71\" y=\"-7.2\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = 1.0</text>\n",
       "</g>\n",
       "<!-- 1&#45;&gt;2 -->\n",
       "<g id=\"edge2\" class=\"edge\">\n",
       "<title>1&#45;&gt;2</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M170.88,-93.21C155.77,-83.61 139.53,-73.29 124.61,-63.81\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"126.63,-60.95 116.31,-58.54 122.88,-66.86 126.63,-60.95\"/>\n",
       "</g>\n",
       "<!-- 3 -->\n",
       "<g id=\"node4\" class=\"node\">\n",
       "<title>3</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"302,-57.5 160,-57.5 160,0 302,0 302,-57.5\"/>\n",
       "<text text-anchor=\"middle\" x=\"231\" y=\"-40.2\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">squared_error = 0.0</text>\n",
       "<text text-anchor=\"middle\" x=\"231\" y=\"-23.7\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 1</text>\n",
       "<text text-anchor=\"middle\" x=\"231\" y=\"-7.2\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = 2.0</text>\n",
       "</g>\n",
       "<!-- 1&#45;&gt;3 -->\n",
       "<g id=\"edge3\" class=\"edge\">\n",
       "<title>1&#45;&gt;3</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M229.09,-93.21C229.33,-85.35 229.58,-77 229.82,-69.03\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"233.31,-69.34 230.11,-59.24 226.31,-69.13 233.31,-69.34\"/>\n",
       "</g>\n",
       "<!-- 5 -->\n",
       "<g id=\"node6\" class=\"node\">\n",
       "<title>5</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"468,-57.5 326,-57.5 326,0 468,0 468,-57.5\"/>\n",
       "<text text-anchor=\"middle\" x=\"397\" y=\"-40.2\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">squared_error = 0.0</text>\n",
       "<text text-anchor=\"middle\" x=\"397\" y=\"-23.7\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 1</text>\n",
       "<text text-anchor=\"middle\" x=\"397\" y=\"-7.2\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = 3.0</text>\n",
       "</g>\n",
       "<!-- 4&#45;&gt;5 -->\n",
       "<g id=\"edge5\" class=\"edge\">\n",
       "<title>4&#45;&gt;5</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M398.91,-93.21C398.67,-85.35 398.42,-77 398.18,-69.03\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"401.69,-69.13 397.89,-59.24 394.69,-69.34 401.69,-69.13\"/>\n",
       "</g>\n",
       "<!-- 6 -->\n",
       "<g id=\"node7\" class=\"node\">\n",
       "<title>6</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"636.12,-57.5 485.88,-57.5 485.88,0 636.12,0 636.12,-57.5\"/>\n",
       "<text text-anchor=\"middle\" x=\"561\" y=\"-40.2\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">squared_error = 0.25</text>\n",
       "<text text-anchor=\"middle\" x=\"561\" y=\"-23.7\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 2</text>\n",
       "<text text-anchor=\"middle\" x=\"561\" y=\"-7.2\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = 4.5</text>\n",
       "</g>\n",
       "<!-- 4&#45;&gt;6 -->\n",
       "<g id=\"edge6\" class=\"edge\">\n",
       "<title>4&#45;&gt;6</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M458.58,-93.21C474.14,-83.56 490.89,-73.19 506.24,-63.68\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"507.65,-66.92 514.3,-58.68 503.96,-60.97 507.65,-66.92\"/>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.sources.Source at 0x280befb7b10>"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Render the fitted tree\n",
    "dot_data = export_graphviz(model, out_file=None)  # A regression tree has no classes, so class_names is omitted\n",
    "graph = graphviz.Source(dot_data)\n",
    "graph"
   ]
  },
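  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The prediction of 4.5 for [9, 9] can be traced by hand through the tree drawn above. This sketch assumes the splits shown in that output; the symbols $\\bar{y}$ (mean target in a node) and $\\mathrm{MSE}$ (the squared_error value of a node) are notation introduced here. Since $x[1] = 9 > 5$ and $x[0] = 9 > 6$, the sample lands in the leaf holding the training targets 4 and 5:\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "\\bar{y}_{\\text{leaf}} &= \\frac{4 + 5}{2} = 4.5 \\\\\n",
    "\\mathrm{MSE}_{\\text{leaf}} &= \\frac{(4 - 4.5)^2 + (5 - 4.5)^2}{2} = 0.25 \\\\\n",
    "\\mathrm{MSE}_{\\text{root}} &= \\frac{(1-3)^2 + (2-3)^2 + (3-3)^2 + (4-3)^2 + (5-3)^2}{5} = 2.0\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "Because max_depth=2 stops the tree before it can separate those two samples, the leaf predicts their mean rather than either exact target."
   ]
  },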
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
